{
 "cells": [
  {
   "cell_type": "raw",
   "id": "ce0e08fd",
   "metadata": {},
   "source": [
    "---\n",
    "sidebar_position: 3\n",
    "keywords: [RunnableLambda, LCEL]\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbc4bf6e",
   "metadata": {},
   "source": [
    "# How to run custom functions\n",
    "\n",
    ":::info Prerequisites\n",
    "\n",
    "This guide assumes familiarity with the following concepts:\n",
    "- [LangChain Expression Language (LCEL)](/docs/concepts/#langchain-expression-language)\n",
    "- [Chaining runnables](/docs/how_to/sequence/)\n",
    "\n",
    ":::\n",
    "\n",
    "You can use arbitrary functions as [Runnables](https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.base.Runnable.html#langchain_core.runnables.base.Runnable). This is useful for formatting or when you need functionality not provided by other LangChain components, and custom functions used as Runnables are called [`RunnableLambdas`](https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.base.RunnableLambda.html).\n",
    "\n",
    "Note that all inputs to these functions need to be a SINGLE argument. If you have a function that accepts multiple arguments, you should write a wrapper that accepts a single dict input and unpacks it into multiple argument.\n",
    "\n",
    "This guide will cover:\n",
    "\n",
    "- How to explicitly create a runnable from a custom function using the `RunnableLambda` constructor and the convenience `@chain` decorator\n",
    "- Coercion of custom functions into runnables when used in chains\n",
    "- How to accept and use run metadata in your custom function\n",
    "- How to stream with custom functions by having them return generators\n",
    "\n",
    "## Using the constructor\n",
    "\n",
    "Below, we explicitly wrap our custom logic using the `RunnableLambda` constructor:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c34d2af",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -qU langchain langchain_openai\n",
    "\n",
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = getpass()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6bb221b3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='3 + 9 equals 12.', response_metadata={'token_usage': {'completion_tokens': 8, 'prompt_tokens': 14, 'total_tokens': 22}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': 'fp_c2295e73ad', 'finish_reason': 'stop', 'logprobs': None}, id='run-73728de3-e483-49e3-ad54-51bd9570e71a-0')"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from operator import itemgetter\n",
    "\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_core.runnables import RunnableLambda\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "\n",
    "def length_function(text):\n",
    "    return len(text)\n",
    "\n",
    "\n",
    "def _multiple_length_function(text1, text2):\n",
    "    return len(text1) * len(text2)\n",
    "\n",
    "\n",
    "def multiple_length_function(_dict):\n",
    "    return _multiple_length_function(_dict[\"text1\"], _dict[\"text2\"])\n",
    "\n",
    "\n",
    "model = ChatOpenAI()\n",
    "\n",
    "prompt = ChatPromptTemplate.from_template(\"what is {a} + {b}\")\n",
    "\n",
    "chain1 = prompt | model\n",
    "\n",
    "chain = (\n",
    "    {\n",
    "        \"a\": itemgetter(\"foo\") | RunnableLambda(length_function),\n",
    "        \"b\": {\"text1\": itemgetter(\"foo\"), \"text2\": itemgetter(\"bar\")}\n",
    "        | RunnableLambda(multiple_length_function),\n",
    "    }\n",
    "    | prompt\n",
    "    | model\n",
    ")\n",
    "\n",
    "chain.invoke({\"foo\": \"bar\", \"bar\": \"gah\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7926002",
   "metadata": {},
   "source": [
    "## The convenience `@chain` decorator\n",
    "\n",
    "You can also turn an arbitrary function into a chain by adding a `@chain` decorator. This is functionaly equivalent to wrapping the function in a `RunnableLambda` constructor as shown above. Here's an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3142a516",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'The subject of the joke is the bear and his girlfriend.'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.runnables import chain\n",
    "\n",
    "prompt1 = ChatPromptTemplate.from_template(\"Tell me a joke about {topic}\")\n",
    "prompt2 = ChatPromptTemplate.from_template(\"What is the subject of this joke: {joke}\")\n",
    "\n",
    "\n",
    "@chain\n",
    "def custom_chain(text):\n",
    "    prompt_val1 = prompt1.invoke({\"topic\": text})\n",
    "    output1 = ChatOpenAI().invoke(prompt_val1)\n",
    "    parsed_output1 = StrOutputParser().invoke(output1)\n",
    "    chain2 = prompt2 | ChatOpenAI() | StrOutputParser()\n",
    "    return chain2.invoke({\"joke\": parsed_output1})\n",
    "\n",
    "\n",
    "custom_chain.invoke(\"bears\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4728ddd9-914d-42ce-ae9b-72c9ce8ec940",
   "metadata": {},
   "source": [
    "Above, the `@chain` decorator is used to convert `custom_chain` into a runnable, which we invoke with the `.invoke()` method.\n",
    "\n",
    "If you are using a tracing with [LangSmith](https://docs.smith.langchain.com/), you should see a `custom_chain` trace in there, with the calls to OpenAI nested underneath.\n",
    "\n",
    "## Automatic coercion in chains\n",
    "\n",
    "When using custom functions in chains with the pipe operator (`|`), you can omit the `RunnableLambda` or `@chain` constructor and rely on coercion. Here's a simple example with a function that takes the output from the model and returns the first five letters of it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5ab39a87",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Once '"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt = ChatPromptTemplate.from_template(\"tell me a story about {topic}\")\n",
    "\n",
    "model = ChatOpenAI()\n",
    "\n",
    "chain_with_coerced_function = prompt | model | (lambda x: x.content[:5])\n",
    "\n",
    "chain_with_coerced_function.invoke({\"topic\": \"bears\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9a481d1",
   "metadata": {},
   "source": [
    "Note that we didn't need to wrap the custom function `(lambda x: x.content[:5])` in a `RunnableLambda` constructor because the `model` on the left of the pipe operator is already a Runnable. The custom function is **coerced** into a runnable. See [this section](/docs/how_to/sequence/#coercion) for more information.\n",
    "\n",
    "## Passing run metadata\n",
    "\n",
    "Runnable lambdas can optionally accept a [RunnableConfig](https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.config.RunnableConfig.html#langchain_core.runnables.config.RunnableConfig) parameter, which they can use to pass callbacks, tags, and other configuration information to nested runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ff0daf0c-49dd-4d21-9772-e5fa133c5f36",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'foo': 'bar'}\n",
      "Tokens Used: 62\n",
      "\tPrompt Tokens: 56\n",
      "\tCompletion Tokens: 6\n",
      "Successful Requests: 1\n",
      "Total Cost (USD): $9.6e-05\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "from langchain_core.runnables import RunnableConfig\n",
    "\n",
    "\n",
    "def parse_or_fix(text: str, config: RunnableConfig):\n",
    "    fixing_chain = (\n",
    "        ChatPromptTemplate.from_template(\n",
    "            \"Fix the following text:\\n\\n```text\\n{input}\\n```\\nError: {error}\"\n",
    "            \" Don't narrate, just respond with the fixed data.\"\n",
    "        )\n",
    "        | model\n",
    "        | StrOutputParser()\n",
    "    )\n",
    "    for _ in range(3):\n",
    "        try:\n",
    "            return json.loads(text)\n",
    "        except Exception as e:\n",
    "            text = fixing_chain.invoke({\"input\": text, \"error\": e}, config)\n",
    "    return \"Failed to parse\"\n",
    "\n",
    "\n",
    "from langchain_community.callbacks import get_openai_callback\n",
    "\n",
    "with get_openai_callback() as cb:\n",
    "    output = RunnableLambda(parse_or_fix).invoke(\n",
    "        \"{foo: bar}\", {\"tags\": [\"my-tag\"], \"callbacks\": [cb]}\n",
    "    )\n",
    "    print(output)\n",
    "    print(cb)"
   ]
  },
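  {
   "cell_type": "markdown",
   "id": "5d2e7a91",
   "metadata": {},
   "source": [
    "To see what a nested run receives without calling a model, here is a minimal sketch with two hypothetical functions, `outer_step` and `inner_step`. It assumes that forwarding `config` to the inner `invoke` carries the caller's tags through to the nested run:\n",
    "\n",
    "```python\n",
    "from langchain_core.runnables import RunnableConfig, RunnableLambda\n",
    "\n",
    "\n",
    "def inner_step(text: str, config: RunnableConfig) -> dict:\n",
    "    # Report the tags that the nested run was given\n",
    "    return {\"text\": text.upper(), \"tags\": config.get(\"tags\", [])}\n",
    "\n",
    "\n",
    "inner = RunnableLambda(inner_step)\n",
    "\n",
    "\n",
    "def outer_step(text: str, config: RunnableConfig) -> dict:\n",
    "    # Pass the received config along to the nested runnable\n",
    "    return inner.invoke(text, config)\n",
    "\n",
    "\n",
    "result = RunnableLambda(outer_step).invoke(\"hello\", {\"tags\": [\"my-tag\"]})\n",
    "print(result)\n",
    "```"
   ]
  },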
  {
   "cell_type": "markdown",
   "id": "922b48bd",
   "metadata": {},
   "source": [
    "## Streaming\n",
    "\n",
    ":::{.callout-note}\n",
    "[RunnableLambda](https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.base.RunnableLambda.html) is best suited for code that does not need to support streaming. If you need to support streaming (i.e., be able to operate on chunks of inputs and yield chunks of outputs), use [RunnableGenerator](https://api.python.langchain.com/en/latest/runnables/langchain_core.runnables.base.RunnableGenerator.html) instead as in the example below.\n",
    ":::\n",
    "\n",
    "You can use generator functions (ie. functions that use the `yield` keyword, and behave like iterators) in a chain.\n",
    "\n",
    "The signature of these generators should be `Iterator[Input] -> Iterator[Output]`. Or for async generators: `AsyncIterator[Input] -> AsyncIterator[Output]`.\n",
    "\n",
    "These are useful for:\n",
    "- implementing a custom output parser\n",
    "- modifying the output of a previous step, while preserving streaming capabilities\n",
    "\n",
    "Here's an example of a custom output parser for comma-separated lists. First, we create a chain that generates such a list as text:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "29f55c38",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lion, tiger, wolf, gorilla, panda"
     ]
    }
   ],
   "source": [
    "from typing import Iterator, List\n",
    "\n",
    "prompt = ChatPromptTemplate.from_template(\n",
    "    \"Write a comma-separated list of 5 animals similar to: {animal}. Do not include numbers\"\n",
    ")\n",
    "\n",
    "str_chain = prompt | model | StrOutputParser()\n",
    "\n",
    "for chunk in str_chain.stream({\"animal\": \"bear\"}):\n",
    "    print(chunk, end=\"\", flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46345323",
   "metadata": {},
   "source": [
    "Next, we define a custom function that will aggregate the currently streamed output and yield it when the model generates the next comma in the list:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f08b8a5b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['lion']\n",
      "['tiger']\n",
      "['wolf']\n",
      "['gorilla']\n",
      "['raccoon']\n"
     ]
    }
   ],
   "source": [
    "# This is a custom parser that splits an iterator of llm tokens\n",
    "# into a list of strings separated by commas\n",
    "def split_into_list(input: Iterator[str]) -> Iterator[List[str]]:\n",
    "    # hold partial input until we get a comma\n",
    "    buffer = \"\"\n",
    "    for chunk in input:\n",
    "        # add current chunk to buffer\n",
    "        buffer += chunk\n",
    "        # while there are commas in the buffer\n",
    "        while \",\" in buffer:\n",
    "            # split buffer on comma\n",
    "            comma_index = buffer.index(\",\")\n",
    "            # yield everything before the comma\n",
    "            yield [buffer[:comma_index].strip()]\n",
    "            # save the rest for the next iteration\n",
    "            buffer = buffer[comma_index + 1 :]\n",
    "    # yield the last chunk\n",
    "    yield [buffer.strip()]\n",
    "\n",
    "\n",
    "list_chain = str_chain | split_into_list\n",
    "\n",
    "for chunk in list_chain.stream({\"animal\": \"bear\"}):\n",
    "    print(chunk, flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a5adb69",
   "metadata": {},
   "source": [
    "Invoking it gives a full array of values:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9ea4ddc6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['lion', 'tiger', 'wolf', 'gorilla', 'raccoon']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list_chain.invoke({\"animal\": \"bear\"})"
   ]
  },
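  {
   "cell_type": "markdown",
   "id": "7c1f4e2a",
   "metadata": {},
   "source": [
    "The same parser can be exercised without a model. In this sketch, the hypothetical `fake_tokens` generator cuts its input into four-character chunks to imitate a token stream. As far as we can tell from the `RunnableGenerator` implementation, `invoke` adds the streamed one-element lists together with `+`, which is why it returns a single flat list:\n",
    "\n",
    "```python\n",
    "from typing import Iterator, List\n",
    "\n",
    "from langchain_core.runnables import RunnableGenerator\n",
    "\n",
    "\n",
    "def fake_tokens(inputs: Iterator[str]) -> Iterator[str]:\n",
    "    # Hypothetical stand-in for a streaming model: emit 4-character chunks\n",
    "    for text in inputs:\n",
    "        for i in range(0, len(text), 4):\n",
    "            yield text[i : i + 4]\n",
    "\n",
    "\n",
    "def split_into_list(input: Iterator[str]) -> Iterator[List[str]]:\n",
    "    # Same comma-splitting logic as the parser above\n",
    "    buffer = \"\"\n",
    "    for chunk in input:\n",
    "        buffer += chunk\n",
    "        while \",\" in buffer:\n",
    "            comma_index = buffer.index(\",\")\n",
    "            yield [buffer[:comma_index].strip()]\n",
    "            buffer = buffer[comma_index + 1 :]\n",
    "    yield [buffer.strip()]\n",
    "\n",
    "\n",
    "offline_chain = RunnableGenerator(fake_tokens) | split_into_list\n",
    "\n",
    "print(list(offline_chain.stream(\"lion, tiger, wolf\")))  # [['lion'], ['tiger'], ['wolf']]\n",
    "print(offline_chain.invoke(\"lion, tiger, wolf\"))  # ['lion', 'tiger', 'wolf']\n",
    "```"
   ]
  },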
  {
   "cell_type": "markdown",
   "id": "96e320ed",
   "metadata": {},
   "source": [
    "## Async version\n",
    "\n",
    "If you are working in an `async` environment, here is an `async` version of the above example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "569dbbef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['lion']\n",
      "['tiger']\n",
      "['wolf']\n",
      "['gorilla']\n",
      "['panda']\n"
     ]
    }
   ],
   "source": [
    "from typing import AsyncIterator\n",
    "\n",
    "\n",
    "async def asplit_into_list(\n",
    "    input: AsyncIterator[str],\n",
    ") -> AsyncIterator[List[str]]:  # async def\n",
    "    buffer = \"\"\n",
    "    async for (\n",
    "        chunk\n",
    "    ) in input:  # `input` is a `async_generator` object, so use `async for`\n",
    "        buffer += chunk\n",
    "        while \",\" in buffer:\n",
    "            comma_index = buffer.index(\",\")\n",
    "            yield [buffer[:comma_index].strip()]\n",
    "            buffer = buffer[comma_index + 1 :]\n",
    "    yield [buffer.strip()]\n",
    "\n",
    "\n",
    "list_chain = str_chain | asplit_into_list\n",
    "\n",
    "async for chunk in list_chain.astream({\"animal\": \"bear\"}):\n",
    "    print(chunk, flush=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3a650482",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['lion', 'tiger', 'wolf', 'gorilla', 'panda']"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await list_chain.ainvoke({\"animal\": \"bear\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3306ac3b",
   "metadata": {},
   "source": [
    "## Next steps\n",
    "\n",
    "Now you've learned a few different ways to use custom logic within your chains, and how to implement streaming.\n",
    "\n",
    "To learn more, see the other how-to guides on runnables in this section."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
